Classification with KNN and Logistic Regression

Classification with K-Nearest-Neighbors

KNN for Categorical Response

We have existing observations

\[(x_1, C_1), \ldots, (x_n, C_n)\]

where the \(C_i\) are categories.



Given a new observation \(x_{new}\), how do we predict \(C_{new}\)?

Predicting a Response with KNN

  1. Find the \(K\) (e.g., 5) values in \((x_1, \ldots, x_n)\) that are closest to \(x_{new}\).
  2. Let all the closest neighbors “vote” on the category.
  3. Predict \(\widehat{C}_{new}\) = the category with the most votes.
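The three steps above can be sketched in base R for a single numeric predictor. This is a minimal illustration under simplifying assumptions: knn_vote() is a hypothetical helper, "distance" is the absolute difference, and every neighbor gets an equal vote (the kknn engine used below also weights neighbors by distance, so it will not always agree). The toy data are made up.

```r
# Hypothetical helper: plain majority-vote KNN with one numeric predictor
knn_vote <- function(x, classes, x_new, k = 5) {
  # Step 1: distance from x_new to every existing observation,
  # then the indices of the k closest (ties keep their original order)
  dist <- abs(x - x_new)
  nearest <- order(dist)[seq_len(k)]
  # Step 2: the k nearest neighbors vote
  votes <- table(classes[nearest])
  # Step 3: predict the category with the most votes
  # (which.max() takes the first category listed if there is a tie)
  names(votes)[which.max(votes)]
}

# Made-up data: one "no" sits right next to a cluster of "yes"
x   <- c(1, 2, 3, 8, 9, 10, 11)
cls <- c("no", "no", "no", "no", "yes", "yes", "yes")

knn_vote(x, cls, x_new = 8.2, k = 1)  # "no": the single nearest point decides
knn_vote(x, cls, x_new = 8.2, k = 3)  # "yes": two "yes" neighbors outvote it
knn_vote(x, cls, x_new = 2.5, k = 3)  # "no"
```

The same new point gets a different prediction with K = 1 and K = 3, which is one reason the choice of K matters.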

KNN for Classification

To perform classification with K-Nearest-Neighbors, we choose the K closest observations to our target, and we aggregate their response values.



The Big Questions:

  • What is our definition of closest?

  • What number should we use for K?
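One common answer to "closest" is Euclidean distance on the predictors. A minimal sketch with two made-up people and hypothetical spreads shows why the scale of each predictor matters: in raw units, a difference in dollars swamps a difference in years. This is the motivation for normalizing predictors (as the recipes later in these slides do with step_normalize()). With a single predictor such as charges alone, rescaling does not change which neighbors are closest.

```r
# Made-up target person and two candidate neighbors
target <- c(age = 30, charges = 5000)
cand <- rbind(
  person_b = c(age = 60, charges = 5100),  # very different age, similar charges
  person_c = c(age = 31, charges = 6000)   # similar age, charges $1,000 higher
)

# Euclidean distance from each row of m to the point p
euclid <- function(m, p) sqrt(rowSums(sweep(m, 2, p)^2))

# Raw units: the charges column dominates the distance
raw_dist <- euclid(cand, target)

# Divide each column by an assumed spread (hypothetical SDs, not from the data)
spread <- c(age = 14, charges = 12000)
scaled_dist <- euclid(sweep(cand, 2, spread, "/"), target / spread)

names(which.min(raw_dist))     # "person_b"
names(which.min(scaled_dist))  # "person_c"
```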

Example

Let’s keep hanging out with the insurance dataset.

Suppose we want to use information about insurance charges to predict whether someone is a smoker or not.

ins
# A tibble: 431 × 6
     age sex      bmi smoker region    charges
   <dbl> <chr>  <dbl> <chr>  <chr>       <dbl>
 1    19 female  27.9 yes    southwest  16885.
 2    33 male    22.7 no     northwest  21984.
 3    32 male    28.9 no     northwest   3867.
 4    31 female  25.7 no     southeast   3757.
 5    60 female  25.8 no     northwest  28923.
 6    25 male    26.2 no     northeast   2721.
 7    62 female  26.3 yes    southeast  27809.
 8    56 female  39.8 no     southeast  11091.
 9    27 male    42.1 yes    southeast  39612.
10    23 male    23.8 no     northeast   2395.
# ℹ 421 more rows

Step 1: Establish the Model

knn_mod <- nearest_neighbor(neighbors = 5) %>%
  set_engine("kknn") %>%
  set_mode("classification")

Notice:

  • New mode - "classification"

  • Everything else is the same!

Step 2: Fit our Model

knn_fit_1 <- knn_mod %>%
  fit(smoker ~ charges, data = ins)
Error in `check_outcome()`:
! For a classification model, the outcome should be a `factor`, not a `character`.




What should we do???

Step 2: Transform our Response

ins <- ins %>%
  mutate(
    smoker = as.factor(smoker)
  ) %>%
  drop_na(smoker)
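A quick base-R check of what as.factor() does to a character column (toy values, not the insurance data). The levels are sorted alphabetically, so "no" is the first level. According to the stats::glm documentation, a binomial model with a factor response treats the first level as failure and the others as success, so the logistic regression later in these slides models the probability of "yes".

```r
# Toy stand-in for the smoker column, including a missing value
smoker_chr <- c("yes", "no", "no", NA, "yes")

# as.factor() builds levels from the sorted unique non-missing values
smoker_fct <- as.factor(smoker_chr)
levels(smoker_fct)   # "no" "yes"

# Remove the missing entry, as drop_na(smoker) does for whole rows
smoker_fct <- smoker_fct[!is.na(smoker_fct)]
length(smoker_fct)   # 4
```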

Step 3: (Re)Fit our Model

knn_fit_1 <- knn_mod %>%
  fit(smoker ~ charges, data = ins)

Heck yeah!

knn_fit_1$fit %>% summary()

Call:
kknn::train.kknn(formula = smoker ~ charges, data = data, ks = min_rows(5,     data, 5))

Type of response variable: nominal
Minimal misclassification: 0.06032
Best kernel: optimal
Best k: 5

Try it!

Open Activity-Classification.qmd.

Select the best KNN model for predicting smoker status.

What metrics does the cross-validation process automatically output?

Logistic Regression

Classification with ordinary linear regression?

lm_mod <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("classification")
Error in `set_mode()`:
! 'classification' is not a known mode for model `linear_reg()`.

Ordinary linear regression with a dummy variable

Consider the following idea:

Convert the smoker variable to a dummy variable:

ins <- ins %>%
  mutate(
    smoker_number = case_when(
      smoker == "yes" ~ 1,
      smoker == "no" ~ 0
    )
  )

Ordinary linear regression with a dummy variable

Fit a linear regression predicting the smoker dummy variable:

lm_mod <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

ins_rec <- recipe(smoker_number ~ charges, data = ins) %>%
  step_normalize(all_numeric_predictors())

ins_wflow <- workflow() %>%
  add_recipe(ins_rec) %>%
  add_model(lm_mod)

ins_fit <- ins_wflow %>%
  fit(ins)

Ordinary linear regression with a dummy variable

Predict each observation's smoker status as whichever value (0 or 1) is closest to its predicted number:

preds <- ins_fit %>% predict(ins)

ins <- ins %>%
  mutate(
    predicted_num = preds$.pred,
    predicted_smoker = round(predicted_num)
  )
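Rounding works for the output below, where every prediction lands on 0 or 1, but it relies on two things worth knowing. A linear model's predictions are not confined to [0, 1], so round() can return values like 2, and R rounds a half to the even integer, so round(0.5) is 0. A minimal sketch with made-up predictions, comparing round() with a hypothetical helper that applies an explicit 0.5 cutoff:

```r
# Made-up fitted values from a linear model on a 0/1 response
predicted_num <- c(-0.2, 0.3, 0.5, 0.7, 1.6)

# round() does not clip to [0, 1] and rounds a half to the even integer
round(predicted_num)               # 0 0 0 1 2

# Hypothetical helper: an explicit cutoff always gives 0 or 1
classify_at_half <- function(p) as.integer(p >= 0.5)
classify_at_half(predicted_num)    # 0 0 1 1 1
```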

Ordinary linear regression

How did we do?

ins %>%
  count(predicted_smoker, smoker_number)
# A tibble: 4 × 3
  predicted_smoker smoker_number     n
             <dbl>         <dbl> <int>
1                0             0   336
2                0             1    28
3                1             0     8
4                1             1    59
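The correct predictions are the rows where the predicted and actual values match (0 and 0, 1 and 1). A small sketch of the accuracy arithmetic, with the counts copied from the table above:

```r
# Counts from the table: (predicted_smoker, smoker_number)
n_00 <- 336  # predicted 0, actually 0
n_01 <- 28   # predicted 0, actually 1
n_10 <- 8    # predicted 1, actually 0
n_11 <- 59   # predicted 1, actually 1

total <- n_00 + n_01 + n_10 + n_11      # 431 people
accuracy_lm <- (n_00 + n_11) / total    # 395 / 431, about 0.916
```

That is roughly 91.6% correct, a useful baseline for the logistic regression accuracy computed later.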


What’s wrong with this?

Residuals

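Linear regression assumes that the residuals are Normally distributed. Here the response is only ever 0 or 1, so for any fitted value a residual can take just two possible values; they cannot be Normal.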

Logistic Regression

Solution: How about the same approach (Y is a function of X plus noise), but we let the noise be non-Normal?




\[Y = g^{-1}(\beta_0 + \beta_1 X + \epsilon) \]

for some function \(g\).

Logistic Regression

Easier way to think of it:

Before:

\[\mu_Y = \beta_0 + \beta_1 X\]

Now:

\[g(\mu_Y) = \beta_0 + \beta_1 X\]

\(g\) is called the link function

Logistic Regression

A common link function is the logit function:

\[g(u) = \log\left(\frac{u}{1-u}\right)\]

In this case, \(u\) represents the probability of someone being a smoker.

For the observations we already have, we know whether each person smokes, so their probability of being a smoker is exactly 0 or 1.

Future observations are unknown, so we predict them.
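The logit and its inverse are short enough to write directly. A minimal sketch (logit() and inv_logit() are names chosen here for illustration); base R also provides the same pair as qlogis() and plogis().

```r
# The logit link: probability -> log-odds
logit <- function(u) log(u / (1 - u))

# Its inverse: log-odds -> probability
inv_logit <- function(eta) 1 / (1 + exp(-eta))

logit(0.5)              # 0: even odds
logit(0.75)             # log(3), about 1.099: odds of 3 to 1
inv_logit(logit(0.2))   # 0.2: the inverse undoes the link

# Base R's built-in versions agree
all.equal(logit(0.2), qlogis(0.2))   # TRUE
```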

Logistic Regression

In summary:

  • Given predictors, we try to predict the log-odds of a person being a smoker.

  • We assume random noise in the relationship between the predictors and the log-odds of the response.

  • From these log-odds, we calculate the probabilities.

  • We compare the probabilities (between 0 and 1) to the observed truths (0 or 1 exactly).
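One standard way to compare the probabilities to the observed truths is the Bernoulli log-likelihood, which glm maximizes for a binomial model. For 0/1 outcomes, the Residual Deviance that glm prints is -2 times this log-likelihood; in the glm output in these slides, 139 + 2 × 2 coefficients gives the reported AIC of 143. A sketch with made-up outcomes and probabilities:

```r
# Made-up observed outcomes (1 = smoker) and predicted probabilities
y <- c(1, 0, 1, 0)
p <- c(0.9, 0.2, 0.6, 0.1)

# Each person contributes log(p) if y = 1, or log(1 - p) if y = 0
log_lik <- sum(y * log(p) + (1 - y) * log(1 - p))

# Deviance for 0/1 data: smaller means probabilities closer to the truths
deviance_val <- -2 * log_lik
```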

Step 1: Establish the Model

New model:

logit_mod <- logistic_reg() %>%
  set_mode("classification") %>%
  set_engine("glm")

Step 2: Make a Recipe

Same recipe but sticking with the original (untransformed) smoker variable now:

ins_rec <- recipe(smoker ~ charges, data = ins) %>%
  step_normalize(all_numeric_predictors())

Step 3: Fit our Model

New workflow:

ins_wflow_logit <- workflow() %>%
  add_recipe(ins_rec) %>%
  add_model(logit_mod)

ins_fit <- ins_wflow_logit %>%
  fit(ins)

ins_fit %>% extract_fit_parsnip()
parsnip model object


Call:  stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)

Coefficients:
(Intercept)      charges  
      -2.62         3.62  

Degrees of Freedom: 430 Total (i.e. Null);  429 Residual
Null Deviance:      434 
Residual Deviance: 139  AIC: 143

Step 4: Get our Predictions

Notice: Now our predictions come back as categories in a .pred_class column! R did the hard part for us.

preds <- predict(ins_fit, new_data = ins)
preds
# A tibble: 431 × 1
   .pred_class
   <fct>      
 1 no         
 2 yes        
 3 no         
 4 no         
 5 yes        
 6 no         
 7 yes        
 8 no         
 9 yes        
10 no         
# ℹ 421 more rows
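What predict() did for us can be retraced by hand from the printed coefficients: compute the log-odds, turn them into a probability, and pick a class. A sketch under assumptions: the coefficients are the rounded values from the glm output, z stands for charges after step_normalize() (standard deviations from the mean), predict_by_hand() is a hypothetical helper, and a 0.5 probability cutoff is assumed as the rule for choosing "yes".

```r
# Coefficients as printed in the glm output (rounded)
b0 <- -2.62   # intercept
b1 <-  3.62   # slope on standardized charges

# Hypothetical helper: z = standardized charges
predict_by_hand <- function(z) {
  log_odds <- b0 + b1 * z
  prob_yes <- 1 / (1 + exp(-log_odds))
  # assumed rule: "yes" when the probability is at least 0.5
  ifelse(prob_yes >= 0.5, "yes", "no")
}

# Probability 0.5 means log-odds 0, i.e. z = -b0 / b1
boundary <- -b0 / b1                   # about 0.72 SDs above the mean charge
predict_by_hand(c(0, 0.72, 0.73, 2))   # "no" "no" "yes" "yes"
```

Because the coefficients are rounded, people very close to the boundary might be classified differently than by the fitted model.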

Log-Odds Predictions

Suppose we wanted to see the predicted log-odds values:

predict(ins_fit, new_data = ins, type = "raw")
        1         2         3         4         5         6         7         8 
-1.225256  0.328346 -5.191275 -5.224858  2.442245 -5.540267  2.102734 -2.990489 
        9        10        11        12        13        14        15        16 
 5.698586 -5.639630  4.853385 -2.339098  4.471990 -5.699642 -5.667924  8.306921 
       17        18        19        20        21        22        23        24 
-5.441333 -0.084214 -5.285697  5.129131 -1.887326 -5.838252 -1.318829 -2.909468 
       25        26        27        28        29        30        31        32 
-5.902457 -5.530000 -4.367436 -3.951805  6.907207 -2.995833 -2.992969 -5.751803 
       33        34        35        36        37        38        39        40 
-3.253843  0.458793 -1.549477 -4.484697 -5.258837  0.133482 -5.495849 -2.901954 
       41        42        43        44        45        46        47        48 
-5.849704 -5.681580 -2.465043 -5.712592 -5.985026 -5.746103 -5.709225 -3.996180 
       49        50        51        52        53        54        55        56 
-0.286973 -4.252375  0.096566  4.887747  4.643787 -2.285313 -5.853003 -5.499408 
       57        58        59        60        61        62        63        64 
 8.505278 -4.881437 -5.718084 -6.022795 -5.869830 -4.648117 -3.431162 -1.829511 
       65        66        67        68        69        70        71        72 
-5.720210 -2.405098  4.943694 -4.191930 -2.791677 -2.140432 -3.214434  1.271953 
       73        74        75        76        77        78        79        80 
-3.150687  4.226450 -0.423825 -5.549963  1.011647 -5.842953 -5.877080  2.625047 
       81        82        83        84        85        86        87        88 
 1.087549 -2.515841 -5.764002  6.151122 -4.723896  3.912862 -2.763082  4.665089 
       89        90        91        92        93        94        95        96 
-5.600635 -5.714180 -3.436747 -5.402960 -5.712266 -5.850026 -1.404112 -2.730351 
       97        98        99       100       101       102       103       104 
-5.840031 -3.894190  4.244503 -3.190237 -4.108428  4.443372 -2.845601  1.062306 
      105       106       107       108       109       110       111       112 
-2.278109 -2.670035 -0.594530 -2.301548 -2.109691 -5.870997 -3.648128 -5.286526 
      113       114       115       116       117       118       119       120 
-2.590123 -2.063631  1.126109 -5.809594 -5.879595 -2.151534 -2.282374 -2.704550 
      121       122       123       124       125       126       127       128 
-5.945395 -0.875194  6.518697 -0.038148 -5.984899 -2.757217 -5.401498 -5.717957 
      129       130       131       132       133       134       135       136 
-4.076283 -1.316173 -5.872235 -3.546765 -1.893137 -5.376095 -5.875382 -3.584495 
      137       138       139       140       141       142       143       144 
 7.915713 -5.538417 -4.138119 -4.550625 -2.519043 -5.682396 -5.486647  4.927082 
      145       146       147       148       149       150       151       152 
-6.019162 -4.131615 -2.879807 -4.568533 -5.769278 -3.139559  0.376957 -5.872123 
      153       154       155       156       157       158       159       160 
-5.608596 -5.698032 -5.837871 -0.008506  4.338537 -5.597215 -5.901898 -5.813553 
      161       162       163       164       165       166       167       168 
-5.875123 -5.044989 -5.987313  8.523766 -5.836558 -5.700165 -2.538585 -5.770068 
      169       170       171       172       173       174       175       176 
-2.265367  4.658476 -4.286293 -5.441896 -2.879424 -5.608177 -3.517357 -5.986593 
      177       178       179       180       181       182       183       184 
-5.194046  1.447071 -5.152485 -3.363513 -4.724920 -2.871090 -5.848980 -0.515476 
      185       186       187       188       189       190       191       192 
 1.961775 -2.138526 -3.252254 -5.390285  7.603120 -2.289831 -5.158960 -2.421881 
      193       194       195       196       197       198       199       200 
-0.922436 -5.170459 -5.609515 -5.720894 -5.223435 -2.645565 -5.862081 -5.869830 
      201       202       203       204       205       206       207       208 
-5.990701 -2.215167 -5.871816 -3.574385 -1.047319 -3.541048 -5.285180 -5.696745 
      209       210       211       212       213       214       215       216 
-5.796558 -2.817629  4.122941 -3.105367  4.177165 -2.673984 -5.230999  6.725791 
      217       218       219       220       221       222       223       224 
-5.343658 -1.978970 -0.230741 -2.364988 -2.582180 -2.638974 -6.023121  1.867903 
      225       226       227       228       229       230       231       232 
 7.556509 -5.167895 -5.028931 -4.484943 -4.168874 -5.990870 -3.364376 -4.712389 
      233       234       235       236       237       238       239       240 
-5.728302 -5.652488 -5.394056 -3.025364 -5.753760 -3.462406  7.157471 -4.703171 
      241       242       243       244       245       246       247       248 
-5.764758 -5.620641 -2.668696 -3.360656 -2.415106 -5.693728  5.944522 -0.810446 
      249       250       251       252       253       254       255       256 
-5.224485 -5.436323 -0.416604  0.657663  4.692019 -5.433794 -2.006968 -2.470598 
      257       258       259       260       261       262       263       264 
-0.962161 -4.100138 -2.487672 -5.785210 -2.269297 -4.644306 -0.790873 -2.759736 
      265       266       267       268       269       270       271       272 
-0.914629 -2.008547  5.449052 -5.816115 -4.013782 -5.718319 -6.024612 -5.797211 
      273       274       275       276       277       278       279       280 
-5.503270 10.427941 -5.552567 -5.875216 -2.288857 -4.764907 -2.791484 -5.028169 
      281       282       283       284       285       286       287       288 
-2.817750 -2.523922  2.019011 -3.300830  7.342869  4.985204  6.460195 -5.797998 
      289       290       291       292       293       294       295       296 
 6.113650 -1.550304 -3.028654 -6.021584 -3.792615 -5.549775 -5.581539  1.922900 
      297       298       299       300       301       302       303       304 
-5.856720  2.566497 -3.193938 -2.239732 -2.389355 -5.871350 -5.724273  8.459308 
      305       306       307       308       309       310       311       312 
-3.893771 -2.517095 -1.038995  3.907498 -5.615818  4.314727 -2.396046 -2.271380 
      313       314       315       316       317       318       319       320 
-5.053097 -4.465105 -2.265384 -4.130683 -2.642235 -3.478873 -5.693768 -5.873209 
      321       322       323       324       325       326       327       328 
-2.414034 -2.855175 -5.347736 -5.599219 -5.833340 -5.661733 -3.991755 -4.879862 
      329       330       331       332       333       334       335       336 
-4.998280 -1.929933 -2.138293 -4.764464  4.132990 -3.850126 -2.675392 -3.688796 
      337       338       339       340       341       342       343       344 
 0.209809 -3.368348 -5.108814 -2.180974 -2.424028 -5.849986 -5.406178 -4.137719 
      345       346       347       348       349       350       351       352 
-1.983958 -5.347680 -3.141664 -0.907080 -5.613488 -5.765632 -2.111944 -5.787584 
      353       354       355       356       357       358       359       360 
-2.346580  4.193709 -3.080170 -3.146969  6.327492 -2.930831  3.960770 -5.859146 
      361       362       363       364       365       366       367       368 
-2.907451 -3.219525 -3.631044 -3.257392 -0.060226 -3.325512 -0.375752 -5.240781 
      369       370       371       372       373       374       375       376 
 4.872050  1.892299  9.652724 -5.680333 -4.547576 -2.641622  5.732399 -5.153806 
      377       378       379       380       381       382       383       384 
-5.698837 -4.776598 -3.023458 -2.989604 -5.712994 -5.485987 -0.474146 -5.500856 
      385       386       387       388       389       390       391       392 
-5.567177 -2.151445 -2.370224 -3.727610 -5.109862 -5.743166 -5.850670 -2.442311 
      393       394       395       396       397       398       399       400 
-4.719423 -3.797231  1.589984 -3.141748 -2.732283 -2.912863 -5.546894 -1.967112 
      401       402       403       404       405       406       407       408 
-2.645122 -5.991082 -1.423914 -4.396991  1.833964 -3.108161  4.066335 -5.796431 
      409       410       411       412       413       414       415       416 
-5.448722 -1.176993 -3.030517 -5.533983 -5.041936 -3.845667 -5.232168 -5.907674 
      417       418       419       420       421       422       423       424 
-2.734455 -5.848698 -1.459724  3.958670 -4.976627 -5.841767  2.191879 -2.414500 
      425       426       427       428       429       430       431 
-2.365149 -4.221507 -3.080296 -5.697268 -5.872794 -5.757601  2.508728 

Probability Predictions

Suppose we wanted to see the predicted probabilities:

predict(ins_fit, new_data = ins, type = "prob")
# A tibble: 431 × 2
   .pred_no .pred_yes
      <dbl>     <dbl>
 1  0.773     0.227  
 2  0.419     0.581  
 3  0.994     0.00553
 4  0.995     0.00535
 5  0.0800    0.920  
 6  0.996     0.00391
 7  0.109     0.891  
 8  0.952     0.0479 
 9  0.00334   0.997  
10  0.996     0.00354
# ℹ 421 more rows
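The probability output is the inverse logit of the raw log-odds. As a check, applying base R's plogis() to the first three log-odds printed by type = "raw" reproduces the first three rows above, up to the rounding in the printout:

```r
# First three log-odds, copied from the type = "raw" output
log_odds <- c(-1.225256, 0.328346, -5.191275)

pred_yes <- plogis(log_odds)  # inverse logit: roughly 0.227, 0.581, 0.00553
pred_no  <- 1 - pred_yes      # roughly 0.773, 0.419, 0.994
```

Only the second person is above 0.5, which matches the "yes" in row 2 of the .pred_class output.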

Plotting our Logistic Regression

pred_probs <- predict(ins_fit, new_data = ins, type = "prob")

ins %>%
  mutate(
    pred_probs = pred_probs$.pred_yes
  ) %>%
  ggplot(mapping = aes(y = pred_probs, x = charges, color = smoker)) +
  geom_point(alpha = 0.75) +
  scale_x_continuous(labels = label_dollar()) +
  labs(x = "Charges", 
       y = "",
       title = "Predicted Probability of Being a Smoker based on Insurance Charges", 
       color = "Smoking Status")

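Plotting our Logistic Regression

[Figure: Predicted Probability of Being a Smoker based on Insurance Charges, colored by Smoking Status]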

Logistic Regression

How many did we get correct?

preds <- ins_fit %>% predict(ins)

ins <- ins %>%
  mutate(
    predicted_smoker = preds$.pred_class
  ) 

ins %>% count(predicted_smoker, smoker)
# A tibble: 4 × 3
  predicted_smoker smoker     n
  <fct>            <fct>  <int>
1 no               no       333
2 no               yes       24
3 yes              no        11
4 yes              yes       63

Logistic Regression

What percentage did we get correct?

ins %>%
  mutate(
    correct = (predicted_smoker == smoker)
  ) %>%
  count(correct) %>%
  mutate(
    pct = n/sum(n)
  )
# A tibble: 2 × 3
  correct     n    pct
  <lgl>   <int>  <dbl>
1 FALSE      35 0.0812
2 TRUE      396 0.919 

Logistic Regression

What percentage did we get correct?

ins %>%
  accuracy(truth = smoker,
           estimate = predicted_smoker)
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.919

Questions to ponder

  • What if we have a categorical variable where 99% of our values are Category A?

  • What if we have a categorical variable with more than 2 categories?

  • What if we want to use a link function other than the logit?

  • Are there other ways to do classification besides logistic regression and KNN?

Try it!

Open Activity-Classification.qmd again.

Select the best logistic regression model for predicting smoker status.

Report the cross-validated metrics. How do they compare to KNN?